import torch
import numpy as np
from pytorch3d.renderer import PerspectiveCameras
from pytorch3d.renderer import FoVPerspectiveCameras

def fov_camera_collate_fn(batch):
    """Collate dataset items holding ``FoVPerspectiveCameras`` into one batch.

    Args:
        batch: sequence of dicts, each with a ``'target_camera'``
            (an ``FoVPerspectiveCameras``) and a ``'target_image'`` (a tensor,
            or ``None``). If the first item has a ``'target_silhouette'`` key,
            silhouettes are collated as well.

    Returns:
        dict with ``'cameras'`` (a single ``FoVPerspectiveCameras`` holding the
        cameras of every item, concatenated along the batch dimension),
        ``'target_images'`` (stacked tensor, or ``None`` when the first item's
        image is ``None``) and, when present, ``'target_silhouettes'``.
    """
    cameras = [item['target_camera'] for item in batch]

    def _cat(attr, *tail_shape):
        # pytorch3d cameras store every parameter with a leading batch dim
        # (R: (N, 3, 3), T: (N, 3), znear/zfar/aspect_ratio/fov: (N,)).
        # torch.stack would insert an extra axis (e.g. R -> (B, 1, 3, 3)) that
        # does not match that layout. Reshaping to the canonical per-camera
        # shape and concatenating along dim 0 works for both batched (1, ...)
        # and unbatched per-item tensors.
        return torch.cat(
            [getattr(cam, attr).reshape(-1, *tail_shape) for cam in cameras],
            dim=0,
        )

    batch_cameras = FoVPerspectiveCameras(
        R=_cat('R', 3, 3),
        T=_cat('T', 3),
        znear=_cat('znear'),
        zfar=_cat('zfar'),
        aspect_ratio=_cat('aspect_ratio'),
        fov=_cat('fov'),
        # Keep the unit of ``fov`` consistent with the source cameras; relying
        # on the constructor default (degrees=True) would misread radians.
        degrees=getattr(cameras[0], 'degrees', True),
    )

    # Mirror perspective_camera_collate_fn: allow items without images.
    if batch[0]['target_image'] is not None:
        batch_images = torch.stack([item['target_image'] for item in batch])
    else:
        batch_images = None

    out = {
        'cameras': batch_cameras,
        'target_images': batch_images,
    }
    if 'target_silhouette' in batch[0]:
        out['target_silhouettes'] = torch.stack(
            [item['target_silhouette'] for item in batch]
        )
    return out

def perspective_camera_collate_fn(batch):
    """Collate dataset items holding ``PerspectiveCameras`` into one batch.

    Args:
        batch: sequence of dicts, each with a ``'target_camera'``
            (a ``PerspectiveCameras``) and a ``'target_image'`` (a tensor, or
            ``None``). If the first item has a ``'target_silhouette'`` key,
            silhouettes are collated as well.

    Returns:
        dict with ``'cameras'`` (one ``PerspectiveCameras`` built from the
        items' R, T, focal_length and principal_point concatenated along
        dim 0), ``'target_images'`` (stacked tensor, or ``None`` when the
        first item's image is ``None``) and, when present,
        ``'target_silhouettes'``.
    """
    cams = [sample['target_camera'] for sample in batch]

    def _concat(attr):
        # Assumes each per-item camera tensor already has a leading batch dim
        # (pytorch3d convention), so concatenating gives one row per camera.
        return torch.cat([getattr(cam, attr) for cam in cams], dim=0)

    # NOTE(review): only R, T, focal_length and principal_point are carried
    # over; settings such as in_ndc / image_size fall back to the
    # constructor defaults — confirm the dataset cameras use those defaults.
    collated = {
        'cameras': PerspectiveCameras(
            R=_concat('R'),
            T=_concat('T'),
            focal_length=_concat('focal_length'),
            principal_point=_concat('principal_point'),
        ),
    }

    if batch[0]['target_image'] is None:
        collated['target_images'] = None
    else:
        collated['target_images'] = torch.stack(
            [sample['target_image'] for sample in batch]
        )

    if 'target_silhouette' in batch[0]:
        collated['target_silhouettes'] = torch.stack(
            [sample['target_silhouette'] for sample in batch]
        )
    return collated